import warnings
import argparse

import os
from pathlib import Path

import omegaconf
import yaml
import time
import shutil

import dreamerv3
from baselines.qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from dreamerv3.agent import ImagActorCritic
from dreamerv3 import embodied
from dreamerv3.embodied.core.goal_sampler import GoalSampler, GoalSamplerCyclic
from eval import eval_data_ours
from evaluation.eval_dcg import reevaluate_saved_data_dcg
from evaluation.eval_smerl import evaluate_and_save_smerl
from utils import Config, get_env, get_argv_from_config

# Suppress every warning whose message matches the regex below (i.e. mentions
# "truncated to dtype int32"). Runs at import time and applies process-wide.
# NOTE(review): the library emitting this warning is not visible in this file.
warnings.filterwarnings("ignore", ".*truncated to dtype int32.*")


# Value used for --num-evals when the flag is not given on the command line.
NUM_EVALS_DEFAULT = 10

def get_args():
  """Parse the command-line arguments of this script.

  Returns:
    The `argparse.Namespace` produced by the parser, with two attributes:
    `path` (str, mandatory) and `num_evals` (int, falls back to
    `NUM_EVALS_DEFAULT`).
  """
  cli = argparse.ArgumentParser()
  cli.add_argument(
    '-p', '--path',
    required=True,
    type=str,
    help='Path to the directory containing all subdirectories of the experiments to evaluate.',
  )
  cli.add_argument(
    '-n', '--num-evals',
    default=NUM_EVALS_DEFAULT,
    type=int,
    help='Number of evaluations to run per goal.',
  )
  return cli.parse_args()


def get_resolution(config):
  """Build a throw-away config and environment, then ask the agent for the resolution.

  Args:
    config: Nested mapping in which each setting is stored under a `'value'`
      key; the `task`, `feat`, `goal` (with a `resolution` entry) and
      `backend` settings are read.
      NOTE(review): this looks like the layout of a wandb `config.yaml` - confirm.

  Returns:
    Whatever `ImagActorCritic.get_resolution` returns for the environment's
    feature space and the parsed config.
  """
  # Command-line style overrides derived from the stored run configuration.
  cli_overrides = [
    f"--task={config['task']['value']}",
    f"--feat={config['feat']['value']}",
    # TODO: is this necessary?
    f"--goal.resolution={config['goal']['value']['resolution']}",
    "--envs.amount=1",
    f"--backend={config['backend']['value']}",
  ]

  # Start from the "defaults" preset, layer the "brax" preset on top, then
  # apply the fixed settings below.
  base_config = (
    embodied.Config(dreamerv3.configs["defaults"])
    .update(dreamerv3.configs["brax"])
    .update({
      "logdir": '.',
      "run.train_ratio": 32,
      "run.log_every": 60,  # Seconds
      "batch_size": 16,
    })
  )
  parsed_config = embodied.Flags(base_config).parse(argv=cli_overrides)

  # The environment is only needed for its feature space.
  env = get_env(parsed_config, mode="train")

  return ImagActorCritic.get_resolution(env.feat_space, parsed_config)

def get_subdirectories(path):
  """Return the immediate child directories of `path`.

  Args:
    path: Directory to inspect (anything accepted by `pathlib.Path`).

  Returns:
    A list of `Path` objects, one per direct child that is a directory, in
    the order yielded by `Path.iterdir` (which is not guaranteed to be sorted).
  """
  children = []
  for entry in Path(path).iterdir():
    if entry.is_dir():
      children.append(entry)
  return children

def eval_for_algorithm(path_saving_evaluations, task_name, algo_name, replication_path, number_evaluations_per_goal, resolution):
  """Run the evaluation routine matching `algo_name` for one replication.

  The output directory `<path_saving_evaluations>/<task_name>/<algo_name>/<replication name>`
  is created before the algorithm name is checked, so it exists even when the
  name turns out to be unknown.

  Args:
    path_saving_evaluations: Root `Path` under which evaluations are stored.
    task_name: Name of the task directory.
    algo_name: Name of the algorithm directory; selects the evaluation routine.
    replication_path: `Path` of the replication to evaluate.
    number_evaluations_per_goal: Number of evaluations forwarded to the routine.
    resolution: Forwarded to the `dcg_me`, `smerl` and `smerl_reverse`
      routines; not passed to the others.

  Raises:
    ValueError: If `algo_name` is not one of the supported names.
  """
  save_dir = path_saving_evaluations / task_name / algo_name / replication_path.name
  save_dir.mkdir(parents=True, exist_ok=True)

  if algo_name in ("ours", "ours_sep_skill", "ours_fixed_lambda", "uvfa"):
    eval_data_ours(replication_path, number_evaluations_per_goal, path_saving_evaluations=save_dir)
    return

  if algo_name == "dcg_me":
    reevaluate_saved_data_dcg(replication_path, save_dir, num_reevals=number_evaluations_per_goal, resolution=resolution)
    return

  # Both SMERL variants share one routine and differ only in `is_reversed`.
  if algo_name in ("smerl", "smerl_reverse"):
    evaluate_and_save_smerl(
      replication_path,
      save_dir,
      number_evaluations_per_goal,
      is_reversed=(algo_name == "smerl_reverse"),
      resolution=resolution,
    )
    return

  raise ValueError(f"Unknown algo name {algo_name}")
  
def check_if_config_exist(path_saving_evaluations, task_name, algo_name, replication_path):
  """Tell whether both config copies are already present for a replication.

  Despite being a check, this also creates the directory
  `<path_saving_evaluations>/<task_name>/<algo_name>/<replication name>`
  (including parents) when it does not exist yet.

  Args:
    path_saving_evaluations: Root directory of the saved evaluations.
    task_name: Name of the task directory.
    algo_name: Name of the algorithm directory.
    replication_path: `Path` of the replication; only its `name` is used.

  Returns:
    True only if both `config_hydra.yaml` and `config_wandb.yaml` exist in
    that directory, False otherwise.
  """
  save_dir = Path(path_saving_evaluations).joinpath(task_name, algo_name, replication_path.name)
  save_dir.mkdir(parents=True, exist_ok=True)

  expected_files = ("config_hydra.yaml", "config_wandb.yaml")
  return all((save_dir / filename).exists() for filename in expected_files)


def copy_config(path_saving_evaluations, task_name, algo_name, replication_path):
  """Copy a replication's hydra and wandb configs next to its saved evaluations.

  The sources are `<replication_path>/.hydra/config.yaml` and
  `<replication_path>/wandb/latest-run/files/config.yaml`; they are written to
  `<path_saving_evaluations>/<task_name>/<algo_name>/<replication name>/` as
  `config_hydra.yaml` and `config_wandb.yaml`. The destination directory is
  created if needed.

  Args:
    path_saving_evaluations: Root directory of the saved evaluations (str or Path).
    task_name: Name of the task directory.
    algo_name: Name of the algorithm directory.
    replication_path: Directory of the replication (str or Path).

  Raises:
    OSError: Propagated from `shutil.copy`, e.g. `FileNotFoundError` when a
      source config is missing. The hydra config is copied first, so it may
      already be in place when the wandb copy fails.
  """
  # Normalise before touching `.name` so that plain strings are accepted too.
  path_results = Path(replication_path)

  destination_dir = Path(path_saving_evaluations) / task_name / algo_name / path_results.name
  destination_dir.mkdir(parents=True, exist_ok=True)

  path_config_hydra = path_results / ".hydra" / "config.yaml"
  path_config_wandb = path_results / "wandb" / "latest-run" / "files" / "config.yaml"

  shutil.copy(path_config_hydra, destination_dir / "config_hydra.yaml")
  shutil.copy(path_config_wandb, destination_dir / "config_wandb.yaml")

def eval_bulk(path_results, number_evaluations_per_goal):
  """Evaluate every replication found under `path_results`.

  The directory is expected to be laid out as
  `<path_results>/<algo>/<task>/<replication>/`. Results are written to a
  sibling directory named `<path_results name>_pre_analysis`.

  The function works in two passes:
    1. For the algorithms 'ours', 'ours_sep_skill', 'ours_fixed_lambda' and
       'uvfa', read the resolution recorded in each replication's wandb
       summary and check that every task has a single, consistent value.
    2. For every supported algorithm, evaluate each replication that has not
       been evaluated yet and copy its config files. A failure in one
       replication is appended to `errors.txt` in the output directory and
       does not stop the remaining replications.

  Args:
    path_results: Directory containing one sub-directory per algorithm.
    number_evaluations_per_goal: Number of evaluations to run per goal.

  Raises:
    ValueError: If the replications of a task report different resolutions.
  """
  path_results = Path(path_results)

  # Evaluations are stored next to the results directory, not inside it.
  path_saving_evaluations = path_results.parents[0] / f"{path_results.name}_pre_analysis"

  algos_paths = get_subdirectories(path_results)
  print("Visiting algo paths:", '\n'.join([str(x) for x in algos_paths]))

  # Only these algorithms' wandb summaries are read for the resolution.
  algos_reporting_resolution = ('ours', 'ours_sep_skill', 'ours_fixed_lambda', 'uvfa')
  all_valid_algos = algos_reporting_resolution + ('dcg_me', 'smerl', 'smerl_reverse')

  print("Collecting resolutions")
  resolution_per_task = _collect_resolution_per_task(algos_paths, algos_reporting_resolution)

  for _algo_path in algos_paths:
    task_paths = get_subdirectories(_algo_path)
    print("Visiting task paths:", '\n'.join([str(x) for x in task_paths]))
    algo_name = _algo_path.name

    if algo_name not in all_valid_algos:
      print(f"Skipping algo {algo_name}")
      continue

    for _task_path in task_paths:
      replications_paths = get_subdirectories(_task_path)
      print("Visiting replication paths:", '\n'.join([str(x) for x in replications_paths]))
      task_name = _task_path.name

      for _replication_path in replications_paths:
        if check_if_config_exist(path_saving_evaluations, task_name, algo_name, _replication_path):
          print(f"Skipping {_replication_path} because it has already been evaluated")
          continue

        # Deliberately broad: one broken replication must not abort the whole
        # batch. The failure is recorded in errors.txt instead.
        try:
          if task_name not in resolution_per_task:
            raise KeyError(f"No resolution could be determined for task {task_name}")
          eval_for_algorithm(path_saving_evaluations=path_saving_evaluations,
                             task_name=task_name,
                             algo_name=algo_name,
                             replication_path=_replication_path,
                             number_evaluations_per_goal=number_evaluations_per_goal,
                             resolution=resolution_per_task[task_name])
          copy_config(path_saving_evaluations=path_saving_evaluations,
                      task_name=task_name,
                      algo_name=algo_name,
                      replication_path=_replication_path)
        except Exception:
          print(f"Exception occurred for {_replication_path}")
          _append_error_report(path_saving_evaluations, _replication_path)


def _read_resolution_from_summary(replication_path):
  """Return `goal/resolution_in_practice` from a replication's wandb summary file."""
  import json

  path_summary = replication_path / "wandb" / "latest-run" / "files" / "wandb-summary.json"
  with open(path_summary, 'r') as f:
    summary = json.load(f)
  return summary["goal/resolution_in_practice"]


def _collect_resolution_per_task(algos_paths, algos_reporting_resolution):
  """Map each task name to the single resolution reported by its replications.

  Resolutions are accumulated across all algorithms in
  `algos_reporting_resolution`, so the consistency check covers every
  replication of a task rather than only the last algorithm visited. Tasks for
  which no summary could be read are left out of the result.

  Raises:
    ValueError: If the replications of a task report different resolutions.
  """
  resolutions_per_task = {}
  for _algo_path in algos_paths:
    if _algo_path.name not in algos_reporting_resolution:
      continue
    for _task_path in get_subdirectories(_algo_path):
      print("task", _task_path)
      # setdefault keeps values gathered from previously visited algorithms.
      task_resolutions = resolutions_per_task.setdefault(_task_path.name, [])
      for _replication_path in get_subdirectories(_task_path):
        print("Loading resolution for", _replication_path)
        try:
          task_resolutions.append(_read_resolution_from_summary(_replication_path))
        except (OSError, ValueError, KeyError, TypeError) as error:
          # Missing file, invalid JSON (JSONDecodeError is a ValueError), or a
          # summary without the expected key: skip this replication only.
          print(f"Skipping {_replication_path} because of missing config ({type(error).__name__}: {error})")

  resolution_per_task = {}
  for task_name, resolutions in resolutions_per_task.items():
    print(f"Task {task_name} has resolutions {resolutions}")
    if not resolutions:
      # Nothing readable for this task; its replications will be reported
      # individually during evaluation instead of crashing here.
      continue
    if not all(resolution == resolutions[0] for resolution in resolutions):
      raise ValueError(f"Task {task_name} has different resolutions: {resolutions}")
    resolution_per_task[task_name] = resolutions[0]
  return resolution_per_task


def _append_error_report(path_saving_evaluations, replication_path):
  """Append the exception currently being handled to `errors.txt`, with a timestamp.

  Must be called from inside an `except` block so that
  `traceback.format_exc()` can see the active exception.
  """
  import traceback

  timestamp = time.strftime("%Y%m%d-%H%M%S")
  # Do not rely on another function having created the output directory.
  path_saving_evaluations.mkdir(parents=True, exist_ok=True)
  with open(path_saving_evaluations / "errors.txt", "a") as file:
    file.write(f"{timestamp} - Exception occurred for {replication_path}\n")
    file.write(traceback.format_exc())
    file.write("\n\n")


def main():
  """Script entry point: parse the command line and launch the bulk evaluation."""
  cli_args = get_args()
  eval_bulk(path_results=cli_args.path, number_evaluations_per_goal=cli_args.num_evals)


# Run the bulk evaluation only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
  main()
